# -*- coding: utf-8 -*-
# Copyright 2019 The TensorFlow Probability Authors and Minh Nguyen (@dathudeptrai)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Weight Norm Modules."""

import warnings

import tensorflow as tf


class WeightNormalization(tf.keras.layers.Wrapper):
    """Layer wrapper to decouple magnitude and direction of the layer's weights.

    This wrapper reparameterizes a layer by decoupling the weight's
    magnitude and direction. This speeds up convergence by improving the
    conditioning of the optimization problem. It has an optional data-dependent
    initialization scheme, in which initial values of weights are set as functions
    of the first minibatch of data. Both the weight normalization and data-
    dependent initialization are described in [Salimans and Kingma (2016)][1].

    The wrapped layer's original kernel becomes the direction parameter `v`;
    the wrapper adds a per-filter magnitude `g` and, before every forward
    pass, sets the wrapped layer's kernel to `g * v / ||v||`.

    #### Example
    ```python
      net = WeightNormalization(
          tf.keras.layers.Conv2D(2, 2, activation='relu'),
          input_shape=(32, 32, 3), data_init=True)(x)
      net = WeightNormalization(
          tf.keras.layers.Conv2DTranspose(16, 5, activation='relu'),
          data_init=True)(net)
      net = WeightNormalization(
          tf.keras.layers.Dense(120, activation='relu'),
          data_init=True)(net)
      net = WeightNormalization(
          tf.keras.layers.Dense(num_classes),
          data_init=True)(net)
    ```

    #### References
    [1]: Tim Salimans and Diederik P. Kingma. Weight Normalization: A Simple
         Reparameterization to Accelerate Training of Deep Neural Networks. In
         _30th Conference on Neural Information Processing Systems_, 2016.
         https://arxiv.org/abs/1602.07868
    """

    # Layer class names for which no "untested layer" warning is emitted. The
    # warning text is generated from this tuple so the two cannot drift apart.
    _TESTED_LAYER_TYPES = (
        "Dense",
        "Conv2D",
        "Conv2DTranspose",
        "Conv1D",
        "GroupConv1D",
    )

    def __init__(self, layer, data_init=True, **kwargs):
        """Initialize WeightNorm wrapper.

        Emits a warning (but still proceeds) when the class name of `layer`
        is not one of `_TESTED_LAYER_TYPES`.

        Args:
          layer: A `tf.keras.layers.Layer` instance. Tested layer types are
            `Dense`, `Conv1D`, `GroupConv1D`, `Conv2D`, and `Conv2DTranspose`.
            Layers with multiple inputs are not supported.
          data_init: `bool`, if `True` use data dependent variable initialization.
          **kwargs: Additional keyword args passed to `tf.keras.layers.Wrapper`.
        Raises:
          ValueError: If `layer` is not a `tf.keras.layers.Layer` instance.
        """
        if not isinstance(layer, tf.keras.layers.Layer):
            raise ValueError(
                "Please initialize `WeightNorm` layer with a `tf.keras.layers.Layer` "
                "instance. You passed: {input}".format(input=layer)
            )

        # Matching is done on the class *name*, so same-named custom layers are
        # treated as tested and subclasses with other names are not.
        layer_type = type(layer).__name__
        if layer_type not in WeightNormalization._TESTED_LAYER_TYPES:
            tested = ", ".join(
                "`{}`".format(name)
                for name in WeightNormalization._TESTED_LAYER_TYPES
            )
            warnings.warn(
                "`WeightNorm` is tested only for {} layers. "
                "You passed a layer of type `{}`".format(tested, layer_type)
            )

        super().__init__(layer, **kwargs)

        self.data_init = data_init
        self._track_trackable(layer, name="layer")
        # Index of the kernel axis that `g` is aligned with: second to last
        # for `Conv2DTranspose`, last for every other layer type.
        self.filter_axis = -2 if layer_type == "Conv2DTranspose" else -1
        # Set to True once `build` has created `v`, `g` and `initialized`.
        self._weight_norm_built = False

    def _compute_weights(self):
        """Generate weights with normalization.

        Sets the wrapped layer's `kernel` to `v` L2-normalized over
        `kernel_norm_axes` and rescaled by `g`.
        """
        # Determine the axis along which to expand `g` so that `g` broadcasts to
        # the shape of `v`: filter_axis == -1 gives -2 (`g` becomes
        # `[1, filters]`), filter_axis == -2 gives -1 (`g` becomes
        # `[filters, 1]`).
        new_axis = -self.filter_axis - 3

        self.layer.kernel = tf.nn.l2_normalize(
            self.v, axis=self.kernel_norm_axes
        ) * tf.expand_dims(self.g, new_axis)

    def _init_norm(self):
        """Set `g` to the norm of `v` taken over `kernel_norm_axes`."""
        kernel_norm = tf.sqrt(
            tf.reduce_sum(tf.square(self.v), axis=self.kernel_norm_axes)
        )
        self.g.assign(kernel_norm)

    def _data_dep_init(self, inputs):
        """Data dependent initialization.

        Runs `inputs` through the wrapped layer with its activation disabled
        and its bias replaced by zeros, then rescales `g` and (if the layer
        has a bias) sets the bias from the mean and variance of that output.

        Args:
          inputs: The first batch of inputs passed to `call`.
        """
        # Normalize kernel first so that calling the layer calculates
        # `tf.dot(v, x)/tf.norm(v)` as in (5) in ([Salimans and Kingma, 2016][1]).
        self._compute_weights()

        activation = self.layer.activation
        bias = self.layer.bias
        use_bias = bias is not None

        # Since the bias is replaced by zeros and the activation is disabled,
        # calling the layer (with normalized kernel) yields the correct
        # computation ((5) in Salimans and Kingma (2016)).
        self.layer.activation = None
        if use_bias:
            self.layer.bias = tf.zeros_like(bias)
        try:
            x_init = self.layer(inputs)
        finally:
            # Always put the real activation and bias back, even when the
            # forward pass raises, so the wrapped layer is never left in the
            # temporary "linear, zero-bias" state.
            self.layer.activation = activation
            if use_bias:
                self.layer.bias = bias

        # Moments over every axis except the last one, i.e. one mean/variance
        # per entry of the output's last axis.
        # NOTE(review): assumes the output's last axis is the filter/unit axis
        # (channels-last) - confirm for custom layers.
        norm_axes_out = list(range(x_init.shape.rank - 1))
        m_init, v_init = tf.nn.moments(x_init, norm_axes_out)
        # The small constant guards against division by zero for zero variance.
        scale_init = 1.0 / tf.sqrt(v_init + 1e-10)

        self.g.assign(self.g * scale_init)
        if use_bias:
            self.layer.bias.assign(-m_init * scale_init)

    def build(self, input_shape=None):
        """Build `Layer`.

        Builds the wrapped layer if necessary, then (once) takes over its
        kernel as `v` and creates the `g` and `initialized` weights. The
        weight-norm variables are created even when the wrapped layer was
        already built before being wrapped.

        Args:
          input_shape: The shape of the input to `self.layer`.
        Raises:
          ValueError: If `Layer` does not contain a `kernel` of weights
        """
        if not self.layer.built:
            self.layer.build(input_shape)

        if not self._weight_norm_built:
            if not hasattr(self.layer, "kernel"):
                raise ValueError(
                    "`WeightNorm` must wrap a layer that"
                    " contains a `kernel` for weights"
                )

            kernel = self.layer.kernel

            # Normalize over every kernel axis except the filter axis.
            norm_axes = list(range(kernel.shape.ndims))
            norm_axes.pop(self.filter_axis)
            self.kernel_norm_axes = norm_axes

            self.v = kernel

            # to avoid a duplicate `kernel` variable after `build` is called
            self.layer.kernel = None
            self.g = self.add_weight(
                name="g",
                shape=(int(self.v.shape[self.filter_axis]),),
                initializer="ones",
                dtype=self.v.dtype,
                trainable=True,
            )
            # Boolean, non-trainable flag recording whether the one-time
            # initialization of `g` in `call` has already happened.
            self.initialized = self.add_weight(
                name="initialized", dtype=tf.bool, trainable=False
            )
            self.initialized.assign(False)
            self._weight_norm_built = True

        super().build()

    def call(self, inputs):
        """Call `Layer`.

        On the first call (while `initialized` is False) `g` is initialized,
        either from the data (`data_init=True`) or from the norm of `v`.
        Every call then recomputes the wrapped layer's kernel from `v` and `g`
        and forwards `inputs` to the wrapped layer.

        Args:
          inputs: Input tensor for the wrapped layer.
        Returns:
          The output of the wrapped layer.
        """
        if not self.initialized:
            if self.data_init:
                self._data_dep_init(inputs)
            else:
                # initialize `g` as the norm of the initialized kernel
                self._init_norm()

            self.initialized.assign(True)

        self._compute_weights()
        output = self.layer(inputs)
        return output

    def compute_output_shape(self, input_shape):
        """Return the wrapped layer's output shape as a `tf.TensorShape`."""
        return tf.TensorShape(self.layer.compute_output_shape(input_shape).as_list())

    def get_config(self):
        """Return the wrapper config, extended with `data_init`.

        Without this entry a `get_config` / `from_config` round trip would
        silently fall back to the default value of `data_init`.
        """
        config = super().get_config()
        config["data_init"] = self.data_init
        return config
